import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from tqdm import tqdm


def auto_detect_task_type(y):
    """
    Auto-detect whether the task is classification or regression based on the target variable.

    The heuristic is:
      * labels that cannot be interpreted as numbers (e.g. strings) -> classification
      * exactly two distinct values, both in {0, 1}                 -> classification
      * at most 20 distinct values, all of them whole numbers       -> classification
      * anything else                                               -> regression

    Parameters:
        y: Target variable array (anything accepted by ``np.unique``)

    Returns:
        str: "classification" or "regression"
    """
    unique_values = np.unique(y)
    n_unique = len(unique_values)

    # Casting non-numeric labels to a number raises; such targets can only be
    # class labels, so decide here instead of crashing further down.
    try:
        numeric_values = unique_values.astype(float)
    except (TypeError, ValueError):
        return "classification"

    # Check if it's binary classification
    if n_unique == 2 and np.all(np.isin(numeric_values, [0, 1])):
        return "classification"

    # Check if it's multi-class classification (small number of integer values).
    # NaN/inf are never treated as whole numbers.
    is_integral = bool(
        np.all(np.isfinite(numeric_values))
        and np.all(numeric_values == np.trunc(numeric_values))
    )
    if n_unique <= 20 and is_integral:
        return "classification"

    # Otherwise, assume regression
    return "regression"

def _fit_progressive_model(features, target, task_type):
    """Fit the estimator for one progressive step (dummy baseline when there are no columns)."""
    if features.shape[1] == 0:
        # A linear model cannot be fitted on zero columns; fall back to a
        # constant predictor (majority class / mean of the target).
        from sklearn.dummy import DummyClassifier, DummyRegressor
        if task_type == "classification":
            model = DummyClassifier(strategy='most_frequent')
        else:  # regression
            model = DummyRegressor(strategy='mean')
    elif task_type == "classification":
        model = LogisticRegression(random_state=42, max_iter=1000)
    else:  # regression
        from sklearn.linear_model import LinearRegression
        model = LinearRegression()
    model.fit(features, target)
    return model


def _score_progressive_model(model, features, target, task_type):
    """Return the test-set metric dict for one fitted progressive-step model."""
    if task_type == "classification":
        predicted_labels = model.predict(features)

        # The dummy baseline has no meaningful probabilities, so its hard
        # predictions are used as scores. Column 1 of predict_proba is taken as
        # the positive-class probability, i.e. a binary target is assumed.
        if hasattr(model, 'predict_proba') and features.shape[1] > 0:
            predicted_scores = model.predict_proba(features)[:, 1]
        else:
            predicted_scores = predicted_labels.astype(float)

        try:
            auc = roc_auc_score(target, predicted_scores)
        except ValueError:
            # AUC is undefined for this target/score combination
            auc = np.nan

        return {
            'auc': auc,
            'f1': f1_score(target, predicted_labels, zero_division=0),
            'accuracy': accuracy_score(target, predicted_labels),
            'precision': precision_score(target, predicted_labels, zero_division=0),
            'recall': recall_score(target, predicted_labels, zero_division=0),
        }

    from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

    predictions = model.predict(features)
    mse = mean_squared_error(target, predictions)

    # MAPE is only defined where the true value is non-zero
    true_values = np.asarray(target)
    nonzero = true_values != 0
    if nonzero.any():
        mape = np.mean(np.abs((true_values[nonzero] - predictions[nonzero]) / true_values[nonzero])) * 100
    else:
        mape = np.nan

    return {
        'r2': r2_score(target, predictions),
        'rmse': np.sqrt(mse),
        'mae': mean_absolute_error(target, predictions),
        'mse': mse,
        'mape': mape,
    }


def progressive_interaction_evaluation(X_train, X_test, y_train, y_test, results_df, max_interactions=100, no_linear_term=True, task_type="classification"):
    """
    Progressively add interactions to linear models and evaluate performance.

    Step 0 is the model without any interaction term; step k uses the first k
    rows of ``results_df``. Each interaction is the product of two standardized
    feature columns (scaler fitted on the training data only).

    Parameters:
        X_train, X_test: Training and test feature matrices
        y_train, y_test: Training and test target vectors
        results_df: DataFrame with interaction rankings; columns 'i' and 'j' hold
                    the feature column indices, optional 'feature_i'/'feature_j'
                    hold display names
        max_interactions: Maximum number of interactions to test (None = all)
        no_linear_term: If True, exclude linear terms and use only interaction terms + bias
        task_type: "classification" or "regression" - determines model type and metrics.
                   The classification metrics use scikit-learn's binary defaults.

    Returns:
        metrics_df: DataFrame with one row per step. Columns: 'step',
                    'n_interactions', 'n_features', the task metrics
                    (auc/f1/accuracy/precision/recall or r2/rmse/mae/mse/mape)
                    and 'interactions_added' (names of the interactions in use).
    """
    # Prepare base features (standardized for better linear model performance)
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Determine number of interactions to test
    n_interactions = len(results_df) if max_interactions is None else min(max_interactions, len(results_df))

    print(f"Testing {n_interactions} interaction terms progressively for {task_type}...")

    # Build every interaction column exactly once; step k then only needs the
    # first k columns instead of recomputing and re-stacking all of them.
    n_columns = max(n_interactions, 0)
    interactions_train = np.empty((X_train_scaled.shape[0], n_columns), dtype=X_train_scaled.dtype)
    interactions_test = np.empty((X_test_scaled.shape[0], n_columns), dtype=X_test_scaled.dtype)
    all_interaction_names = []
    for position in range(n_interactions):
        row = results_df.iloc[position]
        feat_i_idx = int(row['i'])
        feat_j_idx = int(row['j'])

        interactions_train[:, position] = X_train_scaled[:, feat_i_idx] * X_train_scaled[:, feat_j_idx]
        interactions_test[:, position] = X_test_scaled[:, feat_i_idx] * X_test_scaled[:, feat_j_idx]

        feat_i_name = row.get('feature_i', f'x_{feat_i_idx}')
        feat_j_name = row.get('feature_j', f'x_{feat_j_idx}')
        all_interaction_names.append(f'{feat_i_name}*{feat_j_name}')

    metrics_list = []
    for step in tqdm(range(n_interactions + 1), desc="Fitting models"):
        if no_linear_term:
            # Only interaction terms (the model adds the bias itself)
            X_train_current = np.ascontiguousarray(interactions_train[:, :step])
            X_test_current = np.ascontiguousarray(interactions_test[:, :step])
        else:
            # Linear terms followed by the interaction terms
            X_train_current = np.hstack([X_train_scaled, interactions_train[:, :step]])
            X_test_current = np.hstack([X_test_scaled, interactions_test[:, :step]])

        model = _fit_progressive_model(X_train_current, y_train, task_type)
        scores = _score_progressive_model(model, X_test_current, y_test, task_type)

        metrics_list.append({
            'step': step,
            'n_interactions': step,
            'n_features': X_train_current.shape[1],
            **scores,
            'interactions_added': all_interaction_names[:step],
        })

    return pd.DataFrame(metrics_list)


def plot_performance_curves(metrics_df, metric='auc', figsize=(8, 6)):
    """
    Draw how a single metric evolves as interaction terms are added.

    Parameters:
        metrics_df: DataFrame from progressive_interaction_evaluation
                    (needs a 'step' column and the requested metric column)
        metric: Name of the metric column to draw
                Classification: ('auc', 'f1', 'accuracy', 'precision', 'recall')
                Regression: ('r2', 'rmse', 'mae', 'mse', 'mape')
        figsize: Figure size tuple

    Returns:
        The matplotlib figure (it is also shown via plt.show()).

    Raises:
        ValueError: if the metric name is unknown or its column is missing.
    """
    bounded_metrics = ['auc', 'f1', 'accuracy', 'precision', 'recall']
    known_metrics = bounded_metrics + ['r2', 'rmse', 'mae', 'mse', 'mape']

    # One fixed colour per metric (classification first, then regression)
    palette = dict(zip(known_metrics, [
        '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
        '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
    ]))

    # Reject unknown names first, then names that are valid but absent
    if metric not in known_metrics:
        raise ValueError(f"Metric must be one of {known_metrics}, got '{metric}'")
    if metric not in metrics_df.columns:
        present = [name for name in known_metrics if name in metrics_df.columns]
        raise ValueError(f"Metric '{metric}' not found in dataframe. Available metrics: {present}")

    values = metrics_df[metric]
    heading = metric.upper()

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.plot(metrics_df['step'], values,
            marker='o', linewidth=2, markersize=6, color=palette[metric], alpha=0.8)
    ax.set_xlabel('Number of Interactions Added', fontsize=12)
    ax.set_ylabel(heading, fontsize=12)
    ax.set_title(f'{heading} vs Number of Interaction Terms', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)

    # Pick the y-range that suits the kind of metric being drawn
    if metric in bounded_metrics:
        y_range = (0, 1)
    elif metric == 'r2':
        # R² may drop below zero for very poor models
        y_range = (min(0, values.min() * 1.1), 1)
    elif metric == 'mape':
        # Percentage scale: show at least 0-100
        y_range = (0, max(100, values.max() * 1.1))
    else:
        # RMSE / MAE / MSE: anchor at zero with 10% headroom
        y_range = (0, values.max() * 1.1)
    ax.set_ylim(*y_range)

    # Light frame: drop top/right spines, thin out the other two
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    for side in ('left', 'bottom'):
        ax.spines[side].set_linewidth(0.5)

    plt.tight_layout()
    plt.show()

    return fig


def plot_multiple_performance_curves(metrics_dfs, names, metric='auc', figsize=(10, 6), plot_steps=None):
    """
    Plot multiple performance curves in a single figure for comparison.

    Parameters:
        metrics_dfs: List of DataFrames from progressive_interaction_evaluation
                     (each needs a 'step' column and the requested metric column)
        names: List of names for each curve (same length as metrics_dfs)
        metric: Specific metric to plot
                Classification: ('auc', 'f1', 'accuracy', 'precision', 'recall')
                Regression: ('r2', 'rmse', 'mae', 'mse', 'mape')
        figsize: Figure size tuple
        plot_steps: Maximum number of steps to plot on x-axis (None = plot all steps)

    Returns:
        The matplotlib figure (it is also shown via plt.show()). A per-curve
        base/best summary over ALL steps is printed to stdout.

    Raises:
        ValueError: if the inputs are empty or of different lengths, the metric
                    name is unknown, or a DataFrame lacks the metric column.
    """

    # Validate inputs
    if len(metrics_dfs) != len(names):
        raise ValueError(f"Number of DataFrames ({len(metrics_dfs)}) must match number of names ({len(names)})")

    # Define available metrics for both task types
    classification_metrics = ['auc', 'f1', 'accuracy', 'precision', 'recall']
    regression_metrics = ['r2', 'rmse', 'mae', 'mse', 'mape']
    all_metrics = classification_metrics + regression_metrics

    if metric not in all_metrics:
        raise ValueError(f"Metric must be one of {all_metrics}, got '{metric}'")

    # Check if metric exists in all dataframes
    for i, df in enumerate(metrics_dfs):
        if metric not in df.columns:
            available_cols = [col for col in all_metrics if col in df.columns]
            raise ValueError(f"Metric '{metric}' not found in dataframe {i} ({names[i]}). Available metrics: {available_cols}")

    # Fail before any figure is created rather than later inside np.concatenate
    if len(metrics_dfs) == 0:
        raise ValueError("At least one metrics DataFrame is required")

    # Create single plot
    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # Color palette for multiple curves
    colors = plt.cm.Set1(np.linspace(0, 1, len(metrics_dfs)))

    # Plot each curve and remember what was actually drawn for the y-limits
    plotted_values = []
    for color, metrics_df, name in zip(colors, metrics_dfs, names):
        if plot_steps is not None:
            plot_data = metrics_df[metrics_df['step'] <= plot_steps]
        else:
            plot_data = metrics_df

        ax.plot(plot_data['step'], plot_data[metric],
                marker='o', linewidth=2, markersize=4,
                color=color, alpha=0.8, label=name)
        plotted_values.append(plot_data[metric].to_numpy(dtype=float))

    # Styling
    ax.set_xlabel('Number of Interactions Added', fontsize=12)
    ax.set_ylabel(metric.upper(), fontsize=12)
    ax.set_title(f'{metric.upper()} Comparison: Multiple Methods', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)

    # Set y-axis limits from the finite values that are visible; NaN/inf would
    # make set_ylim raise, and hidden steps should not squash the visible curves.
    drawn = np.concatenate(plotted_values)
    finite = drawn[np.isfinite(drawn)]

    if metric in classification_metrics:
        ax.set_ylim(0, 1)
    elif finite.size == 0:
        # Nothing finite to scale by; keep matplotlib's automatic limits
        pass
    elif metric == 'r2':
        # R² can be negative for very poor models, but typically ranges 0-1
        ax.set_ylim(min(0, finite.min() * 1.1), 1)
    elif metric == 'mape':
        # MAPE is percentage, so set reasonable upper limit
        ax.set_ylim(0, max(100, finite.max() * 1.1))
    else:
        # For error metrics (RMSE, MAE, MSE), start from 0
        upper = finite.max() * 1.1
        if upper > 0:
            ax.set_ylim(0, upper)
        else:
            # All errors are exactly zero: only pin the lower bound
            ax.set_ylim(bottom=0)

    # Set x-axis limits based on plot_steps
    if plot_steps is not None:
        ax.set_xlim(0, plot_steps)

    # Add legend
    ax.legend(loc='best', frameon=True, fancybox=True, shadow=True)

    # Clean styling
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(0.5)
    ax.spines['bottom'].set_linewidth(0.5)

    plt.tight_layout()
    plt.show()

    # Print comparison summary
    lower_is_better = metric in ['rmse', 'mae', 'mse', 'mape']
    print(f"\n{metric.upper()} Comparison Summary:")
    print("-" * 50)
    for metrics_df, name in zip(metrics_dfs, names):
        print(f"{name}:")

        scores = metrics_df[metric].to_numpy(dtype=float)
        if scores.size == 0 or np.isnan(scores).all():
            print(f"  No valid {metric} values to summarise")
            continue

        base_score = scores[0]

        # Locate the best row by position so the reported step comes from the
        # 'step' column, not from whatever index labels the DataFrame carries.
        if lower_is_better:
            best_pos = int(np.nanargmin(scores))
            best_score = scores[best_pos]
            improvement = base_score - best_score  # Positive improvement means reduction in error
            improvement_sign = "-" if improvement > 0 else "+"
        else:
            best_pos = int(np.nanargmax(scores))
            best_score = scores[best_pos]
            improvement = best_score - base_score
            improvement_sign = "+" if improvement > 0 else ""
        best_step = metrics_df['step'].iloc[best_pos]

        print(f"  Base: {base_score:.4f} | Best: {best_score:.4f} @ step {best_step} | Improvement: {improvement_sign}{abs(improvement):.4f}")

    return fig



def _fit_stage_model(features, target, task_type):
    """Fit the estimator for one stage (dummy baseline when there are no columns)."""
    if features.shape[1] == 0:
        # Removing features i and j can leave nothing; use a constant predictor
        from sklearn.dummy import DummyClassifier, DummyRegressor
        if task_type == "classification":
            model = DummyClassifier(strategy='most_frequent')
        else:
            model = DummyRegressor(strategy='mean')
    elif task_type == "classification":
        model = LogisticRegression(random_state=42, max_iter=1000)
    else:
        from sklearn.linear_model import LinearRegression
        model = LinearRegression()
    model.fit(features, target)
    return model


def _score_stage_model(model, features, target, task_type):
    """Return the test-set metric dict for one fitted stage model."""
    if task_type == "classification":
        predicted_labels = model.predict(features)

        # The dummy baseline has no meaningful probabilities, so its hard
        # predictions are used as scores. Column 1 of predict_proba is taken as
        # the positive-class probability, i.e. a binary target is assumed.
        if hasattr(model, 'predict_proba') and features.shape[1] > 0:
            predicted_scores = model.predict_proba(features)[:, 1]
        else:
            predicted_scores = predicted_labels.astype(float)

        try:
            auc = roc_auc_score(target, predicted_scores)
        except ValueError:
            # AUC is undefined for this target/score combination
            auc = np.nan

        return {
            'auc': auc,
            'f1': f1_score(target, predicted_labels, zero_division=0),
            'accuracy': accuracy_score(target, predicted_labels),
            'precision': precision_score(target, predicted_labels, zero_division=0),
            'recall': recall_score(target, predicted_labels, zero_division=0),
        }

    from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

    predictions = model.predict(features)
    mse = mean_squared_error(target, predictions)

    # MAPE is only defined where the true value is non-zero
    true_values = np.asarray(target)
    nonzero = true_values != 0
    if nonzero.any():
        mape = np.mean(np.abs((true_values[nonzero] - predictions[nonzero]) / true_values[nonzero])) * 100
    else:
        mape = np.nan

    return {
        'r2': r2_score(target, predictions),
        'rmse': np.sqrt(mse),
        'mae': mean_absolute_error(target, predictions),
        'mse': mse,
        'mape': mape,
    }


def three_stage_interaction_evaluation(X_train, X_test, y_train, y_test, results_df, max_interactions=100, task_type="auto"):
    """
    Evaluate interactions using three-stage comparison:
    1. No features (baseline with only other features, i.e. all columns except i and j)
    2. Linear features (all columns, including i and j individually)
    3. Linear + interaction (all columns + the product i*j)

    For each interaction i,j:
    - g1 = performance gain from adding linear i,j features (stage 2 - stage 1)
    - g2 = performance gain from adding i*j interaction on top of linear i,j (stage 3 - stage 2)
    - Returns g1, g2, and their difference/ratio

    Performance is AUC for classification and R² for regression, measured on
    the test set. Features are standardized with a scaler fitted on the
    training data only.

    Parameters:
        X_train, X_test: Training and test feature matrices
        y_train, y_test: Training and test target vectors
        results_df: DataFrame with interaction rankings; columns 'i' and 'j' hold
                    the feature column indices, optional 'feature_i'/'feature_j'
                    hold display names
        max_interactions: Maximum number of interactions to test (None = all)
        task_type: "classification", "regression", or "auto" - determines model type and metrics

    Returns:
        comparison_df: DataFrame with g1, g2, difference, and ratio for each interaction.
                       'g_ratio' is g2/g1; when g1 is exactly 0 it is +inf, -inf or
                       NaN depending on the sign of g2.
        all_metrics_df: Detailed metrics for all models fitted (three rows per interaction)
    """

    # Initialize results storage
    comparison_list = []
    all_metrics_list = []

    # Prepare base features (standardized)
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Determine number of interactions to test
    n_interactions = len(results_df) if max_interactions is None else min(max_interactions, len(results_df))

    # Auto-detect task type if needed
    if task_type == "auto":
        task_type = auto_detect_task_type(y_train)
        print(f"Auto-detected task type: {task_type}")
    elif task_type == "classification":
        # Check if labels are actually continuous
        unique_labels = np.unique(y_train)
        if len(unique_labels) > 10 or not np.all(np.isin(unique_labels, [0, 1])):
            print(f"Warning: task_type='classification' but labels appear continuous. Unique values: {len(unique_labels)}")
            print("Consider using task_type='regression' for continuous labels.")

    print(f"Testing {n_interactions} interactions with three-stage evaluation for {task_type}...")

    # Primary metric to use for comparison
    primary_metric = 'auc' if task_type == "classification" else 'r2'

    n_total_features = X_train_scaled.shape[1]

    # Stage 2 uses the full standardized matrix for every pair, so its model and
    # metrics are identical across iterations: fit once (lazily) and reuse.
    linear_metrics = None

    # Loop through each interaction
    for idx in tqdm(range(n_interactions), desc="Evaluating interactions"):
        row = results_df.iloc[idx]
        feat_i_idx = int(row['i'])
        feat_j_idx = int(row['j'])
        feat_i_name = row.get('feature_i', f'x_{feat_i_idx}')
        feat_j_name = row.get('feature_j', f'x_{feat_j_idx}')
        interaction_name = f'{feat_i_name}*{feat_j_name}'

        # Stage 1 keeps every column except i and j
        other_features_mask = np.ones(n_total_features, dtype=bool)
        other_features_mask[feat_i_idx] = False
        other_features_mask[feat_j_idx] = False

        # Stage 3 appends the product of the two standardized columns
        interaction_train = X_train_scaled[:, feat_i_idx] * X_train_scaled[:, feat_j_idx]
        interaction_test = X_test_scaled[:, feat_i_idx] * X_test_scaled[:, feat_j_idx]

        stage_inputs = [
            ("no_features",
             X_train_scaled[:, other_features_mask],
             X_test_scaled[:, other_features_mask]),
            ("linear_features", X_train_scaled, X_test_scaled),
            ("linear_plus_interaction",
             np.column_stack([X_train_scaled, interaction_train]),
             np.column_stack([X_test_scaled, interaction_test])),
        ]

        stage_metrics = []
        for stage, (stage_name, X_tr, X_te) in enumerate(stage_inputs, 1):
            if stage_name == "linear_features":
                if linear_metrics is None:
                    linear_metrics = _score_stage_model(
                        _fit_stage_model(X_tr, y_train, task_type), X_te, y_test, task_type)
                metrics = linear_metrics
            else:
                metrics = _score_stage_model(
                    _fit_stage_model(X_tr, y_train, task_type), X_te, y_test, task_type)

            # Store detailed metrics
            all_metrics_list.append({
                'interaction_idx': idx,
                'interaction_name': interaction_name,
                'stage': stage,
                'stage_name': stage_name,
                'n_features': X_tr.shape[1],
                'feat_i_idx': feat_i_idx,
                'feat_j_idx': feat_j_idx,
                **metrics
            })

            stage_metrics.append(metrics[primary_metric])

        # Calculate performance gains
        perf_stage1, perf_stage2, perf_stage3 = stage_metrics

        # g1: gain from adding linear features (stage2 - stage1)
        # g2: gain from adding interaction (stage3 - stage2)
        g1 = perf_stage2 - perf_stage1
        g2 = perf_stage3 - perf_stage2

        # Calculate difference and ratio. With g1 == 0 the ratio is signed
        # infinity (or NaN for 0/0); -np.inf is used because np.NINF no longer
        # exists in NumPy 2.x.
        g_diff = g2 - g1
        if g1 != 0:
            g_ratio = g2 / g1
        elif g2 > 0:
            g_ratio = np.inf
        elif g2 < 0:
            g_ratio = -np.inf
        else:
            g_ratio = np.nan

        # Store comparison results
        comparison_list.append({
            'interaction_idx': idx,
            'interaction_name': interaction_name,
            'feat_i_idx': feat_i_idx,
            'feat_j_idx': feat_j_idx,
            'feat_i_name': feat_i_name,
            'feat_j_name': feat_j_name,
            'perf_stage1_no_features': perf_stage1,
            'perf_stage2_linear': perf_stage2,
            'perf_stage3_interaction': perf_stage3,
            'g1_linear_gain': g1,
            'g2_interaction_gain': g2,
            'g_difference': g_diff,
            'g_ratio': g_ratio,
            'primary_metric': primary_metric
        })

    comparison_df = pd.DataFrame(comparison_list)
    all_metrics_df = pd.DataFrame(all_metrics_list)

    return comparison_df, all_metrics_df


def plot_three_stage_curves(comparison_df, top_k=10, metric_type='gains', figsize=(12, 8)):
    """
    Visualise the output of the three-stage evaluation for the leading interactions.

    Parameters:
        comparison_df: DataFrame from three_stage_interaction_evaluation
        top_k: How many rows to show, taken from the top of comparison_df
               (i.e. in the order the rows already have)
        metric_type: Which view to draw
                    - 'gains': g1 and g2 gains on one axis
                    - 'performance': absolute performance of the three stages
                    - 'ratios': g_ratio (clipped to [-10, 10]) above g_difference
        figsize: Figure size tuple

    Returns:
        The matplotlib figure (it is also shown via plt.show()). Summary
        statistics for the selected rows are printed to stdout.

    Raises:
        ValueError: if metric_type is not one of the three supported views.
    """
    # Rows are taken as-is: the first top_k entries are the ones displayed
    selected = comparison_df.head(top_k).copy()
    ticks = range(len(selected))

    def draw_curve(axis, values, marker, color, **extra):
        # Every curve in this function shares the same line/marker styling
        axis.plot(ticks, values,
                  marker=marker, linewidth=2, markersize=6,
                  color=color, alpha=0.8, **extra)

    def name_ticks(axis):
        # One rotated tick label per interaction
        axis.set_xticks(ticks)
        axis.set_xticklabels(selected['interaction_name'], rotation=45, ha='right')

    def finish_single_axis(axis, y_label, title):
        axis.set_xlabel('Interaction Rank (original ranking)', fontsize=12)
        axis.set_ylabel(y_label, fontsize=12)
        axis.set_title(title, fontsize=14, fontweight='bold')
        axis.grid(True, alpha=0.3)
        axis.legend()
        name_ticks(axis)

    if metric_type == 'gains':
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        axes_to_style = [ax]
        draw_curve(ax, selected['g1_linear_gain'], 'o', '#1f77b4', label='g1 (Linear Gain)')
        draw_curve(ax, selected['g2_interaction_gain'], 's', '#ff7f0e', label='g2 (Interaction Gain)')
        finish_single_axis(
            ax,
            f'Performance Gain ({selected["primary_metric"].iloc[0]})',
            f'Top {top_k} Interactions: Linear vs Interaction Gains',
        )

    elif metric_type == 'performance':
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        axes_to_style = [ax]
        draw_curve(ax, selected['perf_stage1_no_features'], 'v', '#d62728', label='Stage 1: No Features')
        draw_curve(ax, selected['perf_stage2_linear'], 'o', '#1f77b4', label='Stage 2: Linear Features')
        draw_curve(ax, selected['perf_stage3_interaction'], 's', '#2ca02c', label='Stage 3: + Interaction')
        finish_single_axis(
            ax,
            f'Absolute Performance ({selected["primary_metric"].iloc[0]})',
            f'Top {top_k} Interactions: Three-Stage Performance',
        )

    elif metric_type == 'ratios':
        fig, (ratio_ax, diff_ax) = plt.subplots(2, 1, figsize=figsize)
        axes_to_style = [ratio_ax, diff_ax]

        # Upper panel: ratio, clipped so infinite/extreme values stay drawable
        clipped_ratio = np.clip(selected['g_ratio'].copy(), -10, 10)
        draw_curve(ratio_ax, clipped_ratio, 'o', '#9467bd')
        ratio_ax.set_ylabel('g2/g1 Ratio (clipped)', fontsize=12)
        ratio_ax.set_title(f'Top {top_k} Interactions: Gain Ratios and Differences', fontsize=14, fontweight='bold')
        ratio_ax.grid(True, alpha=0.3)
        ratio_ax.axhline(y=1, color='red', linestyle='--', alpha=0.5, label='Equal gains (ratio=1)')
        ratio_ax.legend()

        # Lower panel: difference, carrying the shared x-axis labelling
        draw_curve(diff_ax, selected['g_difference'], 's', '#8c564b')
        diff_ax.set_xlabel('Interaction Rank (original ranking)', fontsize=12)
        diff_ax.set_ylabel('g2 - g1 Difference', fontsize=12)
        diff_ax.grid(True, alpha=0.3)
        diff_ax.axhline(y=0, color='red', linestyle='--', alpha=0.5, label='Equal gains (diff=0)')
        diff_ax.legend()
        name_ticks(diff_ax)

    else:
        raise ValueError(f"metric_type must be one of ['gains', 'performance', 'ratios'], got '{metric_type}'")

    # Light frame on every axis that was drawn
    for axis in axes_to_style:
        for side in ('top', 'right'):
            axis.spines[side].set_visible(False)
        for side in ('left', 'bottom'):
            axis.spines[side].set_linewidth(0.5)

    plt.tight_layout()
    plt.show()

    # Text summary of the same rows
    linear_gain = selected['g1_linear_gain']
    interaction_gain = selected['g2_interaction_gain']
    finite_ratio = selected['g_ratio'].replace([np.inf, -np.inf], np.nan)

    print(f"\nThree-Stage Evaluation Summary (Top {top_k} interactions by original ranking):")
    print("-" * 70)
    print(f"Average g1 (linear gain): {linear_gain.mean():.4f}")
    print(f"Average g2 (interaction gain): {interaction_gain.mean():.4f}")
    print(f"Average g2/g1 ratio: {finite_ratio.mean():.4f}")
    print(f"Average g2-g1 difference: {selected['g_difference'].mean():.4f}")

    # How often the interaction term helps more than the linear terms did
    g2_better_count = (interaction_gain > linear_gain).sum()
    print(f"Interactions where g2 > g1: {g2_better_count}/{len(selected)} ({100*g2_better_count/len(selected):.1f}%)")

    return fig


def plot_multiple_three_stage_curves(X_train, X_test, y_train, y_test, results_dfs, names, 
                                    max_interactions=100, task_type="auto", 
                                    metric_type='g2', figsize=(12, 8), plot_steps=None):
    """
    Draw one three-stage metric curve per ranking method on a shared axis.
    
    Every ranking in results_dfs is handed to three_stage_interaction_evaluation;
    the selected metric column of each returned comparison DataFrame is then
    plotted against its 'interaction_idx' column. After the figure is shown, a
    per-method text summary of the plotted range is printed.
    
    Parameters:
        X_train, X_test: Training and test feature matrices
        y_train, y_test: Training and test target vectors
        results_dfs: List of DataFrames with interaction rankings, one per method
        names: List of method names (must have the same length as results_dfs)
        max_interactions: Maximum number of interactions evaluated per method
        task_type: "classification", "regression", or "auto"; "auto" is resolved
                   with auto_detect_task_type(y_train)
        metric_type: Which curve to draw
                    - 'g1': Linear gain curves
                    - 'g2': Interaction gain curves  
                    - 'g_ratio': g2/g1 ratio curves (drawn clipped to [-10, 10])
                    - 'g_difference': g2-g1 difference curves
        figsize: Figure size tuple
        plot_steps: Only rows with interaction_idx < plot_steps are plotted and
                    summarised (None = use every row)
    
    Returns:
        fig: matplotlib figure
        comparison_dfs: List of comparison DataFrames from three_stage_interaction_evaluation
    
    Raises:
        ValueError: If results_dfs and names differ in length, or metric_type
                    is not one of the four supported values.
    """
    
    # Reject mismatched inputs before any expensive evaluation starts
    if len(names) != len(results_dfs):
        raise ValueError(f"Number of DataFrames ({len(results_dfs)}) must match number of names ({len(names)})")
    
    valid_metrics = ['g1', 'g2', 'g_ratio', 'g_difference']
    if metric_type not in valid_metrics:
        raise ValueError(f"metric_type must be one of {valid_metrics}, got '{metric_type}'")
    
    # Resolve the task type once so every method is evaluated the same way
    if task_type == "auto":
        task_type = auto_detect_task_type(y_train)
        print(f"Auto-detected task type: {task_type}")
    
    n_methods = len(results_dfs)
    print(f"Running three-stage evaluation for {n_methods} methods...")
    
    # One comparison DataFrame per ranking method
    comparison_dfs = []
    for position, (ranking_df, method_name) in enumerate(zip(results_dfs, names), start=1):
        print(f"\nEvaluating method {position}/{n_methods}: {method_name}")
        curve_df, _ = three_stage_interaction_evaluation(
            X_train, X_test, y_train, y_test, ranking_df,
            max_interactions=max_interactions, task_type=task_type
        )
        comparison_dfs.append(curve_df)
    
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    
    # Evenly spaced samples from the Set1 colormap, one per method
    colors = plt.cm.Set1(np.linspace(0, 1, len(comparison_dfs)))
    
    # metric_type -> (DataFrame column, axis label)
    column_and_label = {
        'g1': ('g1_linear_gain', 'Linear Gain (g1)'),
        'g2': ('g2_interaction_gain', 'Interaction Gain (g2)'),
        'g_ratio': ('g_ratio', 'g2/g1 Ratio'),
        'g_difference': ('g_difference', 'g2 - g1 Difference')
    }
    metric_col, y_label = column_and_label[metric_type]
    is_ratio = metric_type == 'g_ratio'
    
    def _within_plot_range(frame):
        # Keep only the rows that fall inside the requested x-range
        if plot_steps is None:
            return frame
        return frame[frame['interaction_idx'] < plot_steps]
    
    # Draw the curves
    for curve_color, method_name, curve_df in zip(colors, names, comparison_dfs):
        visible = _within_plot_range(curve_df)
        y_values = visible[metric_col]
        if is_ratio:
            # Ratios can blow up when g1 is near zero; clip so the plot stays readable
            y_values = np.clip(y_values.copy(), -10, 10)
        
        ax.plot(
            visible['interaction_idx'], y_values,
            marker='o', linewidth=2, markersize=4,
            color=curve_color, alpha=0.8, label=method_name,
        )
    
    # Axis labels and title; the y-label carries the first method's primary metric name
    primary_metric_name = comparison_dfs[0]["primary_metric"].iloc[0]
    ax.set_xlabel('Interaction Rank', fontsize=12)
    ax.set_ylabel(f'{y_label} ({primary_metric_name})', fontsize=12)
    ax.set_title(f'{y_label} Comparison: Multiple Methods', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    
    # Reference line marking "g1 and g2 are equal"
    if is_ratio:
        ax.axhline(y=1, color='red', linestyle='--', alpha=0.5, label='Equal gains (ratio=1)')
        ax.set_ylabel(f'{y_label} (clipped to ±10)', fontsize=12)
    elif metric_type == 'g_difference':
        ax.axhline(y=0, color='red', linestyle='--', alpha=0.5, label='Equal gains (diff=0)')
    
    if plot_steps is not None:
        ax.set_xlim(0, plot_steps-1)
    
    ax.legend(loc='best', frameon=True, fancybox=True, shadow=True)
    
    # Minimal frame: hide top/right spines, thin the remaining two
    for hidden_side in ('top', 'right'):
        ax.spines[hidden_side].set_visible(False)
    for thin_side in ('left', 'bottom'):
        ax.spines[thin_side].set_linewidth(0.5)
    
    plt.tight_layout()
    plt.show()
    
    # Text summary over the same range that was plotted
    print(f"\n{y_label} Comparison Summary:")
    print("-" * 60)
    
    for curve_df, method_name in zip(comparison_dfs, names):
        summary_data = _within_plot_range(curve_df)
        metric_values = summary_data[metric_col]
        
        if is_ratio:
            # Infinite ratios are excluded from the mean/median
            finite_values = metric_values.replace([np.inf, -np.inf], np.nan).dropna()
            if len(finite_values) == 0:
                print(f"{method_name}: No finite values")
            else:
                mean_val = finite_values.mean()
                median_val = finite_values.median()
                print(f"{method_name}:")
                print(f"  Mean {metric_type}: {mean_val:.4f} | Median: {median_val:.4f} | Count: {len(finite_values)}")
            
            # How often the interaction gain beats the linear gain
            g2_better_count = (summary_data['g2_interaction_gain'] > summary_data['g1_linear_gain']).sum()
            total_count = len(summary_data)
            print(f"  g2 > g1: {g2_better_count}/{total_count} ({100*g2_better_count/total_count:.1f}%)")
            continue
        
        mean_val = metric_values.mean()
        median_val = metric_values.median()
        std_val = metric_values.std()
        print(f"{method_name}:")
        print(f"  Mean {metric_type}: {mean_val:.4f} | Median: {median_val:.4f} | Std: {std_val:.4f}")
        
        # Share of strictly positive values, reported for g2 and g_difference only
        if metric_type in ('g2', 'g_difference'):
            positive_count = (metric_values > 0).sum()
            total_count = len(summary_data)
            print(f"  Positive {metric_type}: {positive_count}/{total_count} ({100*positive_count/total_count:.1f}%)")
    
    return fig, comparison_dfs


def _wald_test_logistic_interaction(X_with_interaction, y):
    """Fit a logistic model and Wald-test the coefficient of its last column.

    Returns (coefficient, std_error, z_statistic, p_value). If the information
    matrix cannot be inverted, the coefficient is still returned and the other
    three values are NaN.
    """
    from scipy import stats

    model = LogisticRegression(random_state=42, max_iter=1000)
    model.fit(X_with_interaction, y)
    coef = model.coef_[0][-1]  # Last coefficient is the interaction

    # NOTE(review): coef_[0] and predict_proba[:, 1] only describe the same
    # model for binary targets; multi-class results should not be trusted.
    # NOTE(review): LogisticRegression is L2-regularised by default, so the
    # Wald statistic below is an approximation, not an exact MLE test.
    proba = model.predict_proba(X_with_interaction)[:, 1]
    weights = proba * (1 - proba)

    # The model fits an intercept, so the design matrix used for the
    # covariance needs a matching column of ones.
    design = np.column_stack([np.ones(X_with_interaction.shape[0]), X_with_interaction])

    try:
        # X' W X via broadcasting; avoids materialising an n x n diagonal matrix
        information = (design.T * weights) @ design
        covariance = np.linalg.inv(information)
        std_error = np.sqrt(covariance[-1, -1])
        z_stat = coef / std_error
        # sf() keeps precision for large |z| where 1 - cdf() rounds to 0
        p_value = 2 * stats.norm.sf(np.abs(z_stat))
    except (np.linalg.LinAlgError, ValueError):
        return coef, np.nan, np.nan, np.nan

    return coef, std_error, z_stat, p_value


def _ols_test_interaction(X_with_interaction, y, sm):
    """Fit OLS with an added constant and report the t-test of the last column.

    `sm` is the imported statsmodels.api module.
    Returns (coefficient, std_error, t_statistic, p_value).
    """
    design = sm.add_constant(X_with_interaction)
    fit = sm.OLS(y, design).fit()
    # Results may be label-indexed pandas Series (e.g. when y is a Series);
    # convert to arrays so [-1] is always positional.
    return (
        np.asarray(fit.params)[-1],
        np.asarray(fit.bse)[-1],
        np.asarray(fit.tvalues)[-1],
        np.asarray(fit.pvalues)[-1],
    )


def interaction_significance_evaluation(X_train, X_test, y_train, y_test, results_df, 
                                      max_interactions=100, task_type="auto", alpha=0.05):
    """
    Evaluate statistical significance of interaction terms.
    
    For each interaction i*j (taken in the row order of results_df):
    - Fit a model on all standardized linear features + the product x_i * x_j
    - Test statistical significance of the interaction coefficient
      (Wald z-test for classification, OLS t-test via statsmodels for regression)
    - Record the coefficient, standard error, test statistic and p-value
    
    Only the training data is used; X_test and y_test are accepted so the
    signature matches the other evaluation functions in this module.
    
    Parameters:
        X_train, X_test: Training and test feature matrices
        y_train, y_test: Training and test target vectors
        results_df: DataFrame with interaction rankings; must provide columns
                    'i' and 'j' (feature indices), optionally 'feature_i' and
                    'feature_j' (feature names)
        max_interactions: Maximum number of interactions to test (None = all)
        task_type: "classification", "regression", or "auto" - determines model type
        alpha: Significance level (default 0.05)
    
    Returns:
        significance_df: DataFrame with one row per tested interaction and the
            columns interaction_idx, interaction_name, feat_i_idx, feat_j_idx,
            feat_i_name, feat_j_name, coefficient, std_error, p_value,
            is_significant, alpha, task_type, plus 'z_statistic'
            (classification) or 't_statistic' (regression). Failed fits are
            reported as NaN with is_significant=False.
    """
    
    # Auto-detect task type if needed
    if task_type == "auto":
        task_type = auto_detect_task_type(y_train)
        print(f"Auto-detected task type: {task_type}")
    elif task_type == "classification":
        # Check if labels are actually continuous
        unique_labels = np.unique(y_train)
        if len(unique_labels) > 10 or not np.all(np.isin(unique_labels, [0, 1])):
            print(f"Warning: task_type='classification' but labels appear continuous. Unique values: {len(unique_labels)}")
            print("Consider using task_type='regression' for continuous labels.")
    
    # Prepare base features (standardized)
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    
    # Determine number of interactions to test
    n_interactions = len(results_df) if max_interactions is None else min(max_interactions, len(results_df))
    
    print(f"Testing statistical significance of {n_interactions} interactions for {task_type} (alpha={alpha})...")
    
    is_classification = task_type == "classification"
    stat_key = 'z_statistic' if is_classification else 't_statistic'
    
    # Resolve the regression backend once, so a missing statsmodels is
    # reported a single time instead of once per interaction.
    sm = None
    if not is_classification:
        try:
            import statsmodels.api as sm
        except ImportError:
            sm = None
            print("Warning: statsmodels not available. Using sklearn LinearRegression without p-values.")
            from sklearn.linear_model import LinearRegression
    
    # Initialize results storage
    significance_list = []
    
    # Loop through each interaction
    for idx in tqdm(range(n_interactions), desc="Testing significance"):
        row = results_df.iloc[idx]
        feat_i_idx = int(row['i'])
        feat_j_idx = int(row['j'])
        feat_i_name = row.get('feature_i', f'x_{feat_i_idx}')
        feat_j_name = row.get('feature_j', f'x_{feat_j_idx}')
        interaction_name = f'{feat_i_name}*{feat_j_name}'
        
        # Feature matrix: all linear features + this interaction as the last column
        interaction_train = X_train_scaled[:, feat_i_idx] * X_train_scaled[:, feat_j_idx]
        X_train_with_interaction = np.column_stack([X_train_scaled, interaction_train])
        
        # Start from NaN so values from a previous iteration can never leak through
        interaction_coef = std_error = test_stat = p_value = np.nan
        
        try:
            if is_classification:
                interaction_coef, std_error, test_stat, p_value = _wald_test_logistic_interaction(
                    X_train_with_interaction, y_train
                )
            elif sm is not None:
                interaction_coef, std_error, test_stat, p_value = _ols_test_interaction(
                    X_train_with_interaction, y_train, sm
                )
            else:
                # Fallback: coefficient only, no inference without statsmodels
                model = LinearRegression()
                model.fit(X_train_with_interaction, y_train)
                interaction_coef = model.coef_[-1]
        except Exception as e:
            # Best-effort: one failing interaction must not abort the whole sweep
            print(f"Error fitting model for interaction {interaction_name}: {e}")
            interaction_coef = std_error = test_stat = p_value = np.nan
        
        # A missing p-value is never counted as significant
        is_significant = bool(p_value < alpha) if not np.isnan(p_value) else False
        
        # Store results
        result_dict = {
            'interaction_idx': idx,
            'interaction_name': interaction_name,
            'feat_i_idx': feat_i_idx,
            'feat_j_idx': feat_j_idx,
            'feat_i_name': feat_i_name,
            'feat_j_name': feat_j_name,
            'coefficient': interaction_coef,
            'std_error': std_error,
            'p_value': p_value,
            'is_significant': is_significant,
            'alpha': alpha,
            'task_type': task_type
        }
        
        # Test statistic column is named after the test that produced it
        result_dict[stat_key] = test_stat
            
        significance_list.append(result_dict)
    
    significance_df = pd.DataFrame(significance_list)
    
    # Print summary
    if len(significance_df) > 0:
        total_tests = len(significance_df)
        valid_tests = int(significance_df['p_value'].notna().sum())
        significant_count = int(significance_df['is_significant'].sum())
        
        print(f"\nSignificance Testing Summary:")
        print(f"Total interactions tested: {total_tests}")
        print(f"Valid p-values obtained: {valid_tests}")
        
        if valid_tests > 0:
            print(f"Significant interactions (α={alpha}): {significant_count}/{valid_tests} ({100*significant_count/valid_tests:.1f}%)")
            mean_p_value = significance_df['p_value'].mean()
            median_p_value = significance_df['p_value'].median()
            print(f"Mean p-value: {mean_p_value:.4f}")
            print(f"Median p-value: {median_p_value:.4f}")
        else:
            # Avoid dividing by zero when every fit failed
            print(f"Significant interactions (α={alpha}): {significant_count}/{valid_tests} (no valid p-values)")
    
    return significance_df


def plot_cumulative_significance_curves(X_train, X_test, y_train, y_test, results_dfs, names,
                                      max_interactions=100, task_type="auto", alpha=0.05, 
                                      figsize=(12, 8), plot_steps=None, title=None):
    """
    Compare ranking methods by how quickly they accumulate significant interactions.
    
    Each ranking in results_dfs is run through interaction_significance_evaluation.
    Two stacked subplots are then drawn per method, using only rows that have a
    non-missing p-value: the running count of significant interactions (top) and
    that count divided by the number of valid tests so far (bottom). A text
    summary of the plotted range is printed after the figure is shown.
    
    Parameters:
        X_train, X_test: Training and test feature matrices
        y_train, y_test: Training and test target vectors
        results_dfs: List of DataFrames with interaction rankings, one per method
        names: List of method names (must have the same length as results_dfs)
        max_interactions: Maximum number of interactions tested per method
        task_type: "classification", "regression", or "auto"; "auto" is resolved
                   with auto_detect_task_type(y_train)
        alpha: Significance level (default 0.05)
        figsize: Figure size tuple
        plot_steps: Only rows with interaction_idx < plot_steps are plotted and
                    summarised (None = use every row)
        title: Overall title for the whole figure (default None)
    
    Returns:
        fig: matplotlib figure
        significance_dfs: List of significance DataFrames from interaction_significance_evaluation
    
    Raises:
        ValueError: If results_dfs and names differ in length.
    """
    
    if len(names) != len(results_dfs):
        raise ValueError(f"Number of DataFrames ({len(results_dfs)}) must match number of names ({len(names)})")
    
    # Resolve the task type once so every method is tested with the same model
    if task_type == "auto":
        task_type = auto_detect_task_type(y_train)
        print(f"Auto-detected task type: {task_type}")
    
    n_methods = len(results_dfs)
    print(f"Running significance evaluation for {n_methods} methods...")
    
    # One significance DataFrame per ranking method
    significance_dfs = []
    for position, (ranking_df, method_name) in enumerate(zip(results_dfs, names), start=1):
        print(f"\nEvaluating method {position}/{n_methods}: {method_name}")
        method_df = interaction_significance_evaluation(
            X_train, X_test, y_train, y_test, ranking_df,
            max_interactions=max_interactions, task_type=task_type, alpha=alpha
        )
        significance_dfs.append(method_df)
    
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize)
    
    # Evenly spaced samples from the Set1 colormap, one per method
    colors = plt.cm.Set1(np.linspace(0, 1, len(significance_dfs)))
    
    def _within_plot_range(frame):
        # Keep only the rows that fall inside the requested x-range
        if plot_steps is None:
            return frame
        return frame[frame['interaction_idx'] < plot_steps]
    
    for curve_color, method_name, method_df in zip(colors, names, significance_dfs):
        visible = _within_plot_range(method_df)
        
        # Rows without a p-value are left out of both curves
        tested = visible[visible['p_value'].notna()]
        cumulative_hits = tested['is_significant'].cumsum()
        tests_so_far = np.arange(1, len(cumulative_hits) + 1)
        ranks = tested['interaction_idx'].values
        
        shared_style = dict(linewidth=2, markersize=4, color=curve_color, alpha=0.8, label=method_name)
        
        # Top panel: running count of significant interactions
        ax1.plot(ranks, cumulative_hits, marker='o', **shared_style)
        
        # Bottom panel: running count divided by number of valid tests so far
        hit_ratio = cumulative_hits / tests_so_far
        ax2.plot(ranks, hit_ratio, marker='s', **shared_style)
    
    # Top panel styling
    ax1.set_ylabel('Cumulative Hit Count', fontsize=16)
    ax1.grid(True, alpha=0.3)
    ax1.legend(loc='best')
    
    # Bottom panel styling; a proportion always lives in [0, 1]
    ax2.set_xlabel('Interaction Rank', fontsize=16)
    ax2.set_ylabel('Cumulative Hit Ratio', fontsize=16)
    ax2.grid(True, alpha=0.3)
    ax2.legend(loc='best')
    ax2.set_ylim(0, 1)
    
    if title is not None:
        fig.suptitle(title, fontsize=18, fontweight='bold')
    
    if plot_steps is not None:
        x_max = plot_steps-1
        ax1.set_xlim(0, x_max)
        ax2.set_xlim(0, x_max)
    
    # Minimal frame and integer-only x ticks on both panels
    for panel in (ax1, ax2):
        for hidden_side in ('top', 'right'):
            panel.spines[hidden_side].set_visible(False)
        for thin_side in ('left', 'bottom'):
            panel.spines[thin_side].set_linewidth(0.5)
        panel.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
    
    plt.tight_layout()
    plt.show()
    
    # Text summary over the same range that was plotted
    print(f"\nSignificance Comparison Summary (α={alpha}):")
    print("-" * 70)
    
    for method_df, method_name in zip(significance_dfs, names):
        summary_data = _within_plot_range(method_df)
        
        total_tests = len(summary_data)
        valid_tests = summary_data['p_value'].notna().sum()
        significant_count = summary_data['is_significant'].sum()
        
        if not valid_tests > 0:
            print(f"{method_name}: No valid p-values obtained")
            continue
        
        significant_prop = significant_count / valid_tests
        mean_p_value = summary_data['p_value'].mean()
        
        print(f"{method_name}:")
        print(f"  Significant: {significant_count}/{valid_tests} ({100*significant_prop:.1f}%)")
        print(f"  Mean p-value: {mean_p_value:.4f}")
        print(f"  Valid tests: {valid_tests}/{total_tests}")
    
    return fig, significance_dfs


